# ------------------------------------------------------------------------------
# Plots adverse outcomes (short-term) by stent-hat
# Updates author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 02_plot_adverse_outcomes.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(ggplot2)

# Project-local modules, loaded from lib/: `u` from util.R and `a` from
# aesthetics.R. Their members are accessed below as u$... and a$...
u <- modules::use(here("lib", "util.R"))
a <- modules::use(here("lib", "aesthetics.R"))
# NOTE(review): presumably registers the bundled Optima font so that
# `family = "Optima"` resolves when plotting -- confirm in lib/aesthetics.R.
a$get_font("Optima", here::here("lib", "optima.ttf"))

# Output directory for every figure written by this script.
temp <- here(
  "code", "03_score_diagnostics", "temp",
  "02_plot_adverse_outcomes"
)
if (!dir.exists(temp)) {
  # recursive = TRUE: the parent "temp" directory may not exist yet (e.g. on a
  # fresh checkout); without it dir.create() only warns and the saves fail.
  dir.create(temp, recursive = TRUE)
}

# Helpers ----------------------------------------------------------------------
# Computes mean outcomes for one config row, restricted to one subsample.
#
# Args:
#   subsample:  name of a column of the script-level `cohort`; it is used
#               directly as the row index, so it is expected to be a logical
#               flag (TODO confirm there are no NA flags after the joins).
#   population, outcome, x_var: forwarded unchanged to u$get_mean_outcomes().
#   ...:        ignored; absorbs additional columns supplied by pmap().
#
# Returns: whatever u$get_mean_outcomes() returns for the subsample rows.
get_mean_outcomes <- function(subsample, population, outcome, x_var, ...) {
  u$get_mean_outcomes(
    population, outcome, x_var,
    cohort[cohort[[subsample]], ]
  )
}

# Decides whether a plot should use the large font.
#
# The result is FALSE only for a combined outcome (its name contains both
# "macetrop" and "death") in a subsample other than "noecg", "not_sameday_tn"
# and "not_admit_sym"; every other combination gives TRUE.
#
# Args:
#   subsample: subsample name.
#   outcome:   outcome variable name.
#   ...:       ignored; absorbs additional columns supplied by pmap().
#
# Returns: logical flag.
get_font_size <- function(subsample, outcome, ...) {
  is_combined <- grepl("macetrop", outcome) & grepl("death", outcome)
  outside_restricted <- subsample != "noecg" &
    subsample != "not_sameday_tn" &
    subsample != "not_admit_sym"

  !(is_combined & outside_restricted)
}

# Adds axis labels, a fixed y-axis range and (optionally) titles to a plot.
#
# Args:
#   gg:            ggplot object to decorate.
#   outcome:       name of the outcome on the y-axis; selects the y-axis label,
#                  the title and the y-axis range.
#   x_var:         name of the x-axis variable; only used to pick the y range.
#   mean_outcomes: unused here; accepted so pmap() can pass every config column.
#   subsample:     subsample name; used as the subtitle and to pick the y range.
#   incl_titles:   if TRUE, add a title ("<outcome label> by Predicted Risk")
#                  and the subsample as subtitle.
#   ...:           ignored; absorbs additional columns supplied by pmap().
#
# Returns: the decorated ggplot object.
add_labels <- function(gg, outcome, x_var, mean_outcomes, subsample,
                       incl_titles = FALSE, ...) {
  is_macetrop <- grepl("macetrop", outcome)
  has_death <- grepl("death", outcome)

  # Unknown outcomes fall back to their raw variable name.
  ylabel <- case_when(
    outcome == "death_030_day" ~ "Death Rate",
    outcome == "death_060_day" ~ "Death Rate (60 day)",
    outcome == "death_180_day" ~ "Death Rate (180 day)",
    outcome == "death_365_day" ~ "Death Rate (365 day)",
    is_macetrop & !has_death ~ "Diagnosed Event Rate",
    is_macetrop & has_death ~ "Adverse Event Rate",
    outcome == "untested_noecg" ~ "Fraction Untested and Without ECG",
    outcome == "untested_not_sameday_tn" ~
      "Fraction Untested and Without a Troponin Test",
    TRUE ~ outcome
  )
  xlabel <- "Percentile of Predicted Risk"
  title_lab <- case_when(
    outcome == "death_030_day" ~ "Death (30 day)",
    outcome == "death_060_day" ~ "Death (60 day)",
    outcome == "death_180_day" ~ "Death (180 day)",
    outcome == "death_365_day" ~ "Death (365 day)",
    outcome == "macetrop_030_pos" ~ "Diagnosed Event (Tn > 0)",
    outcome == "macetrop_030_0.05" ~ "Diagnosed Event (Tn > 0.05)",
    outcome == "macetrop_030_0.1" ~ "Diagnosed Event (Tn > 0.1)",
    outcome == "macetrop_030_0.5" ~ "Diagnosed Event (Tn > 0.5)",
    outcome == "macetrop_pos_or_death_030" ~ "Adverse Event (Tn > 0)",
    outcome == "macetrop_0.05_or_death_030" ~ "Adverse Event (Tn > 0.05)",
    outcome == "macetrop_0.1_or_death_030" ~ "Adverse Event (Tn > 0.1)",
    outcome == "macetrop_0.5_or_death_030" ~ "Adverse Event (Tn > 0.5)",
    outcome == "untested_noecg" ~ "Fraction Untested and Without ECG",
    outcome == "untested_not_sameday_tn" ~
      "Fraction Untested and Without Troponin",
    TRUE ~ outcome
  )
  title_lab <- glue("{title_lab} by Predicted Risk")

  # Fixed y-axis range: fractions span [0, 1]; the combined Tn > 0 outcome on
  # the full sample by the 010 tile variable gets a taller axis; everything
  # else is capped at 0.1.
  ymin <- 0
  if (grepl("untested_", outcome)) {
    ymax <- 1
    y_vals <- seq(ymin, ymax, 0.25)
  } else if (
    outcome == "macetrop_pos_or_death_030" && subsample == "full" &&
      x_var == "tile_stent_or_cabg_010_tested"
  ) {
    ymax <- 0.16
    y_vals <- seq(ymin, ymax, 0.02)
  } else {
    ymax <- 0.1
    y_vals <- seq(ymin, ymax, 0.02)
  }

  gg <- gg + labs(x = xlabel, y = ylabel) +
    coord_cartesian(ylim = c(ymin, ymax)) +
    scale_y_continuous(breaks = y_vals, labels = y_vals)

  if (incl_titles) {
    # Use the label built above; a bare `title` would resolve to the
    # graphics::title function rather than a string.
    gg <- gg + labs(title = title_lab, subtitle = subsample)
  }

  # Combined (diagnosed event or death) outcomes get a dashed reference line
  # at 0.02 with an explanatory label.
  if (is_macetrop && has_death) {
    gg <- gg +
      geom_hline(
        yintercept = 0.02, color = "black", alpha = 1, linetype = "dashed",
        size = 0.5
      ) +
      annotate(
        "label",
        x = 2, y = 0.02, label = "Upper Bound: Clinical Risk Threshold",
        family = "Optima", size = 12
      )
  }

  return(gg)
}

# Builds the output path of the PNG for one plot.
#
# Args:
#   outcome:   outcome name; its first "." is dropped so the file name has a
#              single extension dot (e.g. "macetrop_030_0.05" ->
#              "macetrop_030_005").
#   x_var:     x-axis variable; names containing "tile_stent_or_cabg" are
#              shortened to "tile_" plus their number left-padded with zeros
#              to width 3, other names are used as they are.
#   subsample: subsample name.
#   ...:       ignored; absorbs additional columns supplied by pmap().
#
# Returns: character path inside the script-level `temp` directory, named
#   "<outcome>__by__<x_var label>__for__<subsample>.png".
get_filename <- function(outcome, x_var, subsample, ...) {
  x_var_lab <- as.character(x_var)
  is_tile <- grepl("tile_stent_or_cabg", x_var_lab)
  # Only parse the names that actually carry a tile number, so that other
  # variable names do not trigger number-parsing warnings.
  if (any(is_tile)) {
    tile_num <- str_pad(parse_number(x_var_lab[is_tile]), 3, pad = "0")
    x_var_lab[is_tile] <- paste0("tile_", tile_num)
  }

  outcome_lab <- str_remove(outcome, "\\.")
  # Last expression is the value itself (not an assignment), so the path is
  # returned visibly.
  file.path(
    temp,
    glue("{outcome_lab}__by__{x_var_lab}__for__{subsample}.png")
  )
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# NOTE(review): not referenced by name anywhere below; presumably interpolated
# by glue() into the `test_cohort` path template -- confirm in filepaths.yml
# before removing.
overnight_lab <- ""

troponin <- readRDS(paths$analysis$troponin)
subsample_flags <- readRDS(paths$analysis$subsample_flags)

# Test cohort without excluded rows, extended with the troponin and subsample
# flag tables and with two "untested" indicators (no test flagged in
# `test_010_day`, combined with the `noecg` / `not_sameday_tn` flags).
cohort <- readRDS(glue(paths$analysis$test_cohort)) %>%
  filter(!exclude) %>%
  u$safe_left_join(troponin) %>%
  u$safe_left_join(subsample_flags) %>%
  mutate(
    untested_noecg = noecg & !test_010_day,
    untested_not_sameday_tn = not_sameday_tn & !test_010_day
  )

# Outcomes for Tn levels -------------------------------------------------------
message("Building adverse outcome vars using different tn cutoffs...")
tn_thresholds <- c("pos", "0.05", "0.1", "0.5")

# The window and the death column are the same for every threshold.
window <- 30
window_pd <- str_pad(window, width = 3, side = "left", pad = "0")
death_var <- glue("death_{window_pd}_day")

# For each threshold, add two columns to `cohort`:
#   macetrop_<window>_<tn>           copy of max_trop_<window>_<tn>
#   macetrop_<tn>_or_death_<window>  that column OR death within the window
for (tn_level in tn_thresholds) {
  message("Tn: ", tn_level)
  maxtrop_var <- glue("max_trop_{window_pd}_{tn_level}")
  macetrop_var <- glue("macetrop_{window_pd}_{tn_level}")
  macetrop_death_var <- glue("macetrop_{tn_level}_or_death_{window_pd}")

  # setnames() renames in place, so work on a copy and leave `cohort` intact.
  df <- copy(cohort)
  setnames(df, c(maxtrop_var, death_var), c("macetrop", "death"))
  df <- select(
    mutate(df, macetrop_or_death = macetrop | death),
    ed_enc_id, macetrop, macetrop_or_death
  )
  setnames(
    df,
    c("macetrop", "macetrop_or_death"),
    c(macetrop_var, macetrop_death_var)
  )

  # NOTE(review): presumably joins on the shared `ed_enc_id` column -- confirm
  # in lib/util.R.
  cohort <- u$safe_left_join(cohort, df)
}

# Outcome Config ---------------------------------------------------------------
message("Preparing config table...")
tri_palette <- a$cont_ramp(10)[c(3, 2, 1)]
death_horizons <- c("030")

# One row per plot: every subsample x outcome x x_var combination, reduced by
# the filter below. All four conditions must hold for a row to be kept:
#   1. subsamples other than "full" only keep outcomes whose name contains
#      "pos" or "death";
#   2. subsamples other than "full" are only plotted by the 004 tile variable;
#   3. the 010 tile variable is only used for "or_death" outcomes;
#   4. "untested_" outcomes are only plotted for "full" by the 004 variable.
config <- crossing(
  subsample = c("full", "noecg", "not_sameday_tn", "not_admit_sym"),
  outcome = c(
    glue("macetrop_030_{tn_thresholds}"), glue("death_{death_horizons}_day"),
    glue("macetrop_{tn_thresholds}_or_death_030"),
    "untested_noecg", "untested_not_sameday_tn"
  ),
  x_var = c("tile_stent_or_cabg_010_tested", "tile_stent_or_cabg_004_tested")
) %>%
  filter(
    (grepl("pos", outcome) | grepl("death", outcome) | subsample == "full") &
      (subsample == "full" | x_var == "tile_stent_or_cabg_004_tested") &
      (grepl("or_death", outcome) | x_var == "tile_stent_or_cabg_004_tested") &
      (
        !grepl("untested_", outcome) | (
          subsample == "full" & x_var == "tile_stent_or_cabg_004_tested"
        )
      )
  ) %>%
  # Default: cycle the three colours over the rows. `length.out` is spelled
  # out in full instead of relying on partial matching of `length`.
  mutate(palette = rep(tri_palette, length.out = nrow(.))) %>%
  # Overrides, first match wins: "untested_" outcomes use the first colour;
  # plots by the 010 tile variable and plots of subsamples other than "full"
  # use the third.
  mutate(
    palette = case_when(
      grepl("untested_", outcome) ~ tri_palette[1],
      grepl("010", x_var) ~ tri_palette[3],
      subsample != "full" ~ tri_palette[3],
      TRUE ~ palette
    )
  )

# Plotting Means ---------------------------------------------------------------
message("Plotting adverse outcome rates...")

# Each step calls a helper once per row via pmap(), passing all columns built
# so far as named arguments, and stores the results as a new column. The order
# matters: later helpers receive the columns added by earlier ones.
plots <- config
plots <- mutate(plots, population = unlist(pmap(plots, u$get_population)))
plots <- mutate(plots, mean_outcomes = pmap(plots, get_mean_outcomes))
plots <- mutate(plots, large_font = pmap(plots, get_font_size))
plots <- mutate(plots, gg = pmap(plots, a$tile_plot))
plots <- mutate(plots, plot = pmap(plots, add_labels))
# To include titles in the plots, use this line instead of the one above:
# plots <- mutate(plots, plot = pmap(plots, add_labels, incl_titles = TRUE))
plots <- mutate(plots, filename = pmap(plots, get_filename))

# Save -------------------------------------------------------------------------
message("Saving...")

# One ggsave() call per row; the `plot` and `filename` columns are matched to
# ggsave()'s arguments by name, and every figure is written at 10 x 7 inches.
saves <- pmap(
  select(plots, plot, filename),
  ggsave,
  width = 10, height = 7, units = "in"
)

message("Done.")
